{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 在整个测试集上运行一遍"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras.models import Model\n",
    "from keras.optimizers import Adam\n",
    "from keras import backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import cv2\n",
    "from losses import bce_dice_loss\n",
    "from rssegnet import RSSegVGGNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "def adjust_data(img):\n",
    "    mean_ = np.array([77.79408427,  89.68194542, 101.41694189])\n",
    "    img = (img - mean_) / 255\n",
    "    return img\n",
    "\n",
    "def make_test_generator():\n",
    "    test_data_root = 'D:/Data/mc_data/test'  # 测试集存放路径\n",
    "\n",
    "    data_gen_args = dict()\n",
    "    test_image_datagen = ImageDataGenerator(**data_gen_args)\n",
    "    test_image_generator = test_image_datagen.flow_from_directory(\n",
    "        test_data_root + '/images',\n",
    "        target_size=(288,288),\n",
    "        class_mode=None,\n",
    "        batch_size=16,\n",
    "        color_mode='rgb',\n",
    "        shuffle=False)\n",
    "\n",
    "    print(len(test_image_generator.filenames))\n",
    "    bidx = 1\n",
    "    for raw_images in test_image_generator:\n",
    "        if bidx > np.ceil(60288/test_image_generator.batch_size):\n",
    "            break\n",
    "            \n",
    "        idx = (bidx-1) * test_image_generator.batch_size\n",
    "        filenames = test_image_generator.filenames[idx : idx + test_image_generator.batch_size]\n",
    "        images = adjust_data(raw_images)           \n",
    "        bidx += 1\n",
    "        yield filenames,images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "image_input (InputLayer)        (None, 288, 288, 3)  0                                            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_1 (Conv2D)               (None, 288, 288, 64) 1792        image_input[0][0]                \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_2 (Conv2D)               (None, 288, 288, 64) 36928       conv2d_1[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2D)  (None, 144, 144, 64) 0           conv2d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_3 (Conv2D)               (None, 144, 144, 128 73856       max_pooling2d_1[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_4 (Conv2D)               (None, 144, 144, 128 147584      conv2d_3[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_2 (MaxPooling2D)  (None, 72, 72, 128)  0           conv2d_4[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_5 (Conv2D)               (None, 72, 72, 256)  295168      max_pooling2d_2[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_6 (Conv2D)               (None, 72, 72, 256)  590080      conv2d_5[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_7 (Conv2D)               (None, 72, 72, 256)  590080      conv2d_6[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_3 (MaxPooling2D)  (None, 36, 36, 256)  0           conv2d_7[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_8 (Conv2D)               (None, 36, 36, 512)  1180160     max_pooling2d_3[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_9 (Conv2D)               (None, 36, 36, 512)  2359808     conv2d_8[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_10 (Conv2D)              (None, 36, 36, 512)  2359808     conv2d_9[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_4 (MaxPooling2D)  (None, 18, 18, 512)  0           conv2d_10[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_11 (Conv2D)              (None, 18, 18, 512)  2359808     max_pooling2d_4[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_12 (Conv2D)              (None, 18, 18, 512)  2359808     conv2d_11[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_13 (Conv2D)              (None, 18, 18, 512)  2359808     conv2d_12[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "max_pooling2d_5 (MaxPooling2D)  (None, 9, 9, 512)    0           conv2d_13[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_14 (Conv2D)              (None, 9, 9, 512)    2359808     max_pooling2d_5[0][0]            \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_transpose_1 (Conv2DTrans (None, 18, 18, 256)  1179904     conv2d_14[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_1 (Concatenate)     (None, 18, 18, 768)  0           conv2d_transpose_1[0][0]         \n",
      "                                                                 conv2d_13[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_15 (Conv2D)              (None, 18, 18, 512)  3539456     concatenate_1[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_transpose_2 (Conv2DTrans (None, 36, 36, 256)  1179904     conv2d_15[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_2 (Concatenate)     (None, 36, 36, 768)  0           conv2d_transpose_2[0][0]         \n",
      "                                                                 conv2d_10[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_16 (Conv2D)              (None, 36, 36, 512)  3539456     concatenate_2[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_transpose_3 (Conv2DTrans (None, 72, 72, 128)  589952      conv2d_16[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_3 (Concatenate)     (None, 72, 72, 384)  0           conv2d_transpose_3[0][0]         \n",
      "                                                                 conv2d_7[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_17 (Conv2D)              (None, 72, 72, 256)  884992      concatenate_3[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_transpose_4 (Conv2DTrans (None, 144, 144, 64) 147520      conv2d_17[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_4 (Concatenate)     (None, 144, 144, 192 0           conv2d_transpose_4[0][0]         \n",
      "                                                                 conv2d_4[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_18 (Conv2D)              (None, 144, 144, 128 221312      concatenate_4[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_transpose_5 (Conv2DTrans (None, 288, 288, 32) 36896       conv2d_18[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "concatenate_5 (Concatenate)     (None, 288, 288, 96) 0           conv2d_transpose_5[0][0]         \n",
      "                                                                 conv2d_2[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "conv2d_19 (Conv2D)              (None, 288, 288, 1)  865         concatenate_5[0][0]              \n",
      "==================================================================================================\n",
      "Total params: 28,394,753\n",
      "Trainable params: 28,394,753\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = RSSegVGGNet.build()\n",
    "model.compile(loss=bce_dice_loss, optimizer=Adam(), metrics=['accuracy'])\n",
    "model.load_weights(\"building_seg_vgg16_bcedice_0829.h5\")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 60288 images belonging to 1 classes.\n",
      "60288\n",
      "1000\n",
      "2000\n",
      "3000\n"
     ]
    }
   ],
   "source": [
    "output_dir = 'D:/Data/mc_data/test/predicted_masks_vgg16_bcedice/'\n",
    "\n",
    "batch_idx = 0\n",
    "# test数据集预测结果\n",
    "for filenames,images in make_test_generator():\n",
    "#     print(filenames)\n",
    "    predicts = model.predict(images)\n",
    "#     print(predicts.shape)\n",
    "    for idx,image in enumerate(predicts):\n",
    "        # 先转换类型再处理\n",
    "        image = (np.squeeze(image)*255).astype(np.uint8)\n",
    "        image = cv2.resize(image, (300,300))\n",
    "        # threshold = 0.5\n",
    "        image = cv2.threshold(image, 127, 255,cv2.THRESH_BINARY)[1]\n",
    "        fullname = output_dir + filenames[idx]\n",
    "        cv2.imwrite(fullname, image)\n",
    "    batch_idx += 1\n",
    "    if batch_idx % 1000 == 0:\n",
    "        print(batch_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
